// NOTE(review): not referenced by the kernel visible in this file; presumably
// the patch edge length in pixels used elsewhere -- confirm before changing.
#define patchSide 8

// Work-group size along dimension 0. The kernel below enforces it through
// reqd_work_group_size and uses it as the row count of its __local buffer.
#define WGS_W 64
// Alias of WGS_W; no active code in the visible kernel references it.
#define ITEMS WGS_W

// Capacity of the per-work-item sorted list kept by the kernel (slot 0 plus
// N-1 candidate entries).
#define N 16
// NOTE(review): looks like log2(N); not referenced in the visible kernel --
// confirm and keep in sync with N if it is used elsewhere.
#define NSHIFT 4

/*
 * sort -- per work-item selection of the N-1 closest candidates.
 *
 * Each work-item handles one (ind_i, ind_j) position of the w_ind x h_ind
 * grid, with global id = ind_j * w_ind_size + ind_i.  It reads 256
 * (offset, distance) candidates for that position and writes N offsets:
 *
 *   slot 0        the position's own offset, h_ind[ind_j]*width + w_ind[ind_i]
 *   slots 1..N-1  offsets of the candidates with the smallest distances,
 *                 in ascending distance order
 *
 * offsets256  in   candidate offsets, 256 consecutive entries per work-item
 * dists256    in   candidate distances, same layout as offsets256
 * offsets16   out  selected offsets, N consecutive entries per work-item
 * w_ind       in   indexed by ind_i (global id % w_ind_size)
 * h_ind       in   indexed by ind_j (global id / w_ind_size)
 * w_ind_size, h_ind_size
 *                  grid dimensions; work-items with global id at or beyond
 *                  w_ind_size * h_ind_size exit without writing anything
 * width            multiplier applied to h_ind[] when forming the own offset
 *
 * The kernel contains no barriers, so the early exit of surplus work-items
 * is safe.
 */
__kernel __attribute__((reqd_work_group_size(WGS_W, 1, 1)))
void sort(
          __global unsigned*     offsets256,
          __global float*        dists256,
          __global unsigned*     offsets16,
          __global unsigned*     w_ind,
          __global unsigned*     h_ind,
          const int              w_ind_size,
          const int              h_ind_size,
          const int              width
          )
{
    typedef struct dist_t { unsigned offset; float dist; } dist_t;

    // One N-entry list per work-item of the group; each work-item only ever
    // touches its own row, so no synchronisation is needed.
    __local dist_t buf[WGS_W][N];

    // Number of candidates stored per work-item in offsets256 / dists256.
    const int numCandidates = 256;

    const size_t gid = get_global_id(0);

    // Widen before multiplying so a large grid cannot overflow int.
    const size_t itemCount = (size_t)w_ind_size * (size_t)h_ind_size;
    if(gid >= itemCount)
        return;

    const size_t ind_i = gid % (size_t)w_ind_size;
    const size_t ind_j = gid / (size_t)w_ind_size;
    const unsigned originOffset = h_ind[ind_j]*width + w_ind[ind_i];

    __local dist_t* best = buf[get_local_id(0)];

    // Slot 0 is pinned to the position's own offset with distance 0; the
    // search below starts at index 1, so nothing is ever inserted before it.
    best[0].offset = originOffset;
    best[0].dist = 0;
    unsigned count = 1;

    const size_t base = gid * (size_t)numCandidates;

    for(int k = 0; k < numCandidates; k++)
    {
        const float dist = dists256[base + k];
        const unsigned offset = offsets256[base + k];

        // Lower-bound binary search over [1, count): first slot whose
        // distance is not smaller than the candidate's.  A candidate equal
        // to stored entries therefore lands in front of them.
        unsigned lo = 1;
        unsigned hi = count;
        while(lo < hi)
        {
            const unsigned mid = (lo + hi) >> 1;

            if(best[mid].dist < dist)
                lo = mid + 1;
            else
                hi = mid;
        }

        // lo == N means the candidate is no better than anything in a full
        // list and is dropped.
        if(lo < N)
        {
            // Grow until the list is full; afterwards the last entry is
            // overwritten by the shift, i.e. the current worst is discarded.
            if(count < N)
                count++;

            // lo >= 1, so j never reaches 0 and the unsigned index cannot wrap.
            for(unsigned j = count - 1; j > lo; j--)
                best[j] = best[j - 1];

            best[lo].offset = offset;
            best[lo].dist = dist;
        }
    }

    // ind_j * w_ind_size + ind_i is the global id again; the output stride is
    // the list capacity N rather than a separate literal.
    __global unsigned* g_offsets16 = offsets16 + gid * (size_t)N;

#pragma unroll
    for(int i = 0; i < N; i++)
        g_offsets16[i] = best[i].offset;
}
